{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 0
   },
   "source": [
    "# Concise Implementation of Recurrent Neural Networks\n",
    ":label:`sec_rnn-concise`\n",
    "\n",
    "While :numref:`sec_rnn_scratch` was instructive to see how RNNs are implemented,\n",
    "this is not convenient or fast.\n",
    "This section will show how to implement the same language model more efficiently\n",
    "using functions provided by high-level APIs\n",
    "of a deep learning framework.\n",
    "We begin as before by reading the time machine dataset.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load ../utils/djl-imports\n",
    "%load ../utils/plot-utils\n",
    "%load ../utils/PlotUtils.java\n",
    "\n",
    "%load ../utils/Accumulator.java\n",
    "%load ../utils/Animator.java\n",
    "%load ../utils/Functions.java\n",
    "%load ../utils/StopWatch.java\n",
    "%load ../utils/Training.java\n",
    "%load ../utils/timemachine/Vocab.java\n",
    "%load ../utils/timemachine/RNNModelScratch.java\n",
    "%load ../utils/timemachine/TimeMachine.java"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import ai.djl.training.dataset.Record;"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "NDManager manager = NDManager.newBaseManager();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Creating a Dataset in DJL"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In DJL, the ideal and concise way of dealing with datasets, is to use the built-in datasets that can easily wrap around existing NDArrays or to create your own dataset that extends from the `RandomAccessDataset` class. For this section, we will be implementing our own. For more information on creating your own dataset in DJL, you can refer to: https://djl.ai/docs/development/how_to_use_dataset.html\n",
    "\n",
    "Our implementation of `TimeMachineDataset` will be a concise replacement of the `SeqDataLoader` class previously created. Using a dataset in DJL format, will allow us to use already built-in functions so we don't have to implement most things from scratch. We have to implement a Builder, a prepare function which will contain the process to save the data to the TimeMachineDataset object, and finally a get function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "public static class TimeMachineDataset extends RandomAccessDataset {\n",
    "\n",
    "    private Vocab vocab;\n",
    "    private NDArray data;\n",
    "    private NDArray labels;\n",
    "    private int numSteps;\n",
    "    private int maxTokens;\n",
    "    private int batchSize;\n",
    "    private NDManager manager;\n",
    "    private boolean prepared;\n",
    "\n",
    "    public TimeMachineDataset(Builder builder) {\n",
    "        super(builder);\n",
    "        this.numSteps = builder.numSteps;\n",
    "        this.maxTokens = builder.maxTokens;\n",
    "        this.batchSize = builder.getSampler().getBatchSize();\n",
    "        this.manager = builder.manager;\n",
    "        this.data = this.manager.create(new Shape(0,35), DataType.INT32);\n",
    "        this.labels = this.manager.create(new Shape(0,35), DataType.INT32);\n",
    "        this.prepared = false;\n",
    "    }\n",
    "\n",
    "    @Override\n",
    "    public Record get(NDManager manager, long index) throws IOException {\n",
    "        NDArray X = data.get(new NDIndex(\"{}\", index));\n",
    "        NDArray Y = labels.get(new NDIndex(\"{}\", index));\n",
    "        return new Record(new NDList(X), new NDList(Y));\n",
    "    }\n",
    "\n",
    "    @Override\n",
    "    protected long availableSize() {\n",
    "        return data.getShape().get(0);\n",
    "    }\n",
    "\n",
    "    @Override\n",
    "    public void prepare(Progress progress) throws IOException, TranslateException {\n",
    "        if (prepared) {\n",
    "            return;\n",
    "        }\n",
    "\n",
    "        Pair<List<Integer>, Vocab> corpusVocabPair = null;\n",
    "        try {\n",
    "            corpusVocabPair = TimeMachine.loadCorpusTimeMachine(maxTokens);\n",
    "        } catch (Exception e) {\n",
    "            e.printStackTrace(); // Exception can be from unknown token type during tokenize() function.\n",
    "        }\n",
    "        List<Integer> corpus = corpusVocabPair.getKey();\n",
    "        this.vocab = corpusVocabPair.getValue();\n",
    "\n",
    "        // Start with a random offset (inclusive of `numSteps - 1`) to partition a\n",
    "        // sequence\n",
    "        int offset = new Random().nextInt(numSteps);\n",
    "        int numTokens = ((int) ((corpus.size() - offset - 1) / batchSize)) * batchSize;\n",
    "        NDArray Xs =\n",
    "                manager.create(\n",
    "                        corpus.subList(offset, offset + numTokens).stream()\n",
    "                                .mapToInt(Integer::intValue)\n",
    "                                .toArray());\n",
    "        NDArray Ys =\n",
    "                manager.create(\n",
    "                        corpus.subList(offset + 1, offset + 1 + numTokens).stream()\n",
    "                                .mapToInt(Integer::intValue)\n",
    "                                .toArray());\n",
    "        Xs = Xs.reshape(new Shape(batchSize, -1));\n",
    "        Ys = Ys.reshape(new Shape(batchSize, -1));\n",
    "        int numBatches = (int) Xs.getShape().get(1) / numSteps;\n",
    "\n",
    "        NDList xNDList = new NDList();\n",
    "        NDList yNDList = new NDList();\n",
    "        for (int i = 0; i < numSteps * numBatches; i += numSteps) {\n",
    "            NDArray X = Xs.get(new NDIndex(\":, {}:{}\", i, i + numSteps));\n",
    "            NDArray Y = Ys.get(new NDIndex(\":, {}:{}\", i, i + numSteps));\n",
    "            xNDList.add(X);\n",
    "            yNDList.add(Y);\n",
    "        }\n",
    "        this.data = NDArrays.concat(xNDList);\n",
    "        xNDList.close();\n",
    "        this.labels = NDArrays.concat(yNDList);\n",
    "        yNDList.close();\n",
    "        this.prepared = true;\n",
    "    }\n",
    "\n",
    "    public Vocab getVocab() {\n",
    "        return this.vocab;\n",
    "    }\n",
    "\n",
    "    public static final class Builder extends BaseBuilder<Builder> {\n",
    "        int numSteps;\n",
    "        int maxTokens;\n",
    "        NDManager manager;\n",
    "\n",
    "        @Override\n",
    "        protected Builder self() { return this; }\n",
    "\n",
    "        public Builder setSteps(int steps) {\n",
    "            this.numSteps = steps;\n",
    "            return this;\n",
    "        }\n",
    "\n",
    "        public Builder setMaxTokens(int maxTokens) {\n",
    "            this.maxTokens = maxTokens;\n",
    "            return this;\n",
    "        }\n",
    "\n",
    "        public Builder setManager(NDManager manager) {\n",
    "            this.manager = manager;\n",
    "            return this;\n",
    "        }\n",
    "\n",
    "        public TimeMachineDataset build() throws IOException, TranslateException {\n",
    "            TimeMachineDataset dataset = new TimeMachineDataset(this);\n",
    "            return dataset;\n",
    "        }\n",
    "    }\n",
    "}"
   ]
  },
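  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the partitioning in `prepare` concrete, the following cell is a minimal sketch in plain Java, without DJL. It assumes a toy corpus whose token indices are simply `0, 1, 2, ...` and takes the offset as a parameter instead of drawing it at random. `PartitionSketch` and `partition` are hypothetical names used only for illustration. The sketch computes the same inputs `X` as `prepare`; the labels are the same windows shifted by one token."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import java.util.ArrayList;\n",
    "import java.util.List;\n",
    "\n",
    "/** Hypothetical plain-Java mirror of the sequential partitioning in TimeMachineDataset.prepare(). */\n",
    "class PartitionSketch {\n",
    "    /** Returns the input minibatches, each of shape [batchSize][numSteps]. */\n",
    "    static List<int[][]> partition(int[] corpus, int offset, int batchSize, int numSteps) {\n",
    "        // Largest multiple of batchSize that still leaves one extra token for the labels\n",
    "        int numTokens = ((corpus.length - offset - 1) / batchSize) * batchSize;\n",
    "        int rowLength = numTokens / batchSize; // columns after reshaping to (batchSize, -1)\n",
    "        int numBatches = rowLength / numSteps;\n",
    "        List<int[][]> batches = new ArrayList<>();\n",
    "        for (int b = 0; b < numBatches; b++) {\n",
    "            int[][] x = new int[batchSize][numSteps];\n",
    "            for (int row = 0; row < batchSize; row++) {\n",
    "                for (int t = 0; t < numSteps; t++) {\n",
    "                    // Row `row` of the reshaped array starts at corpus[offset + row * rowLength]\n",
    "                    x[row][t] = corpus[offset + row * rowLength + b * numSteps + t];\n",
    "                }\n",
    "            }\n",
    "            batches.add(x);\n",
    "        }\n",
    "        return batches;\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For example, with the corpus `0, 1, ..., 34`, offset 1, a batch size of 2 and 5 time steps, the first minibatch has the rows `[1, 2, 3, 4, 5]` and `[17, 18, 19, 20, 21]`, and the second minibatch continues each row with `[6, ..., 10]` and `[22, ..., 26]`. Because each row resumes where it stopped in the previous minibatch, the hidden state at the end of one minibatch can initialize the next one. Sequential partitioning relies on exactly this property."
   ]
  },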
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Consequently we will update our code from the previous section for the functions `predictCh8`, `trainCh8`, `trainEpochCh8`, and `gradClipping` to include the dataset logic and also allow the functions to accept an `AbstractBlock` from DJL instead of just accepting `RNNModelScratch`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "/** Generate new characters following the `prefix`. */\n",
    "public static String predictCh8(\n",
    "        String prefix,\n",
    "        int numPreds,\n",
    "        Object net,\n",
    "        Vocab vocab,\n",
    "        Device device,\n",
    "        NDManager manager) {\n",
    "\n",
    "    List<Integer> outputs = new ArrayList<>();\n",
    "    outputs.add(vocab.getIdx(\"\" + prefix.charAt(0)));\n",
    "    Functions.SimpleFunction<NDArray> getInput =\n",
    "            () ->\n",
    "                    manager.create(outputs.get(outputs.size() - 1))\n",
    "                            .toDevice(device, false)\n",
    "                            .reshape(new Shape(1, 1));\n",
    "\n",
    "    if (net instanceof RNNModelScratch) {\n",
    "        RNNModelScratch castedNet = (RNNModelScratch) net;\n",
    "        NDList state = castedNet.beginState(1, device);\n",
    "\n",
    "        for (char c : prefix.substring(1).toCharArray()) { // Warm-up period\n",
    "            state = (NDList) castedNet.forward(getInput.apply(), state).getValue();\n",
    "            outputs.add(vocab.getIdx(\"\" + c));\n",
    "        }\n",
    "\n",
    "        NDArray y;\n",
    "        for (int i = 0; i < numPreds; i++) {\n",
    "            Pair<NDArray, NDList> pair = castedNet.forward(getInput.apply(), state);\n",
    "            y = pair.getKey();\n",
    "            state = pair.getValue();\n",
    "\n",
    "            outputs.add((int) y.argMax(1).reshape(new Shape(1)).getLong(0L));\n",
    "        }\n",
    "    } else {\n",
    "        AbstractBlock castedNet = (AbstractBlock) net;\n",
    "        NDList state = null;\n",
    "        for (char c : prefix.substring(1).toCharArray()) { // Warm-up period\n",
    "            if (state == null) {\n",
    "                // Begin state\n",
    "                state =\n",
    "                        castedNet\n",
    "                                .forward(\n",
    "                                        new ParameterStore(manager, false),\n",
    "                                        new NDList(getInput.apply()),\n",
    "                                        false)\n",
    "                                .subNDList(1);\n",
    "            } else {\n",
    "                state =\n",
    "                        castedNet\n",
    "                                .forward(\n",
    "                                        new ParameterStore(manager, false),\n",
    "                                        new NDList(getInput.apply()).addAll(state),\n",
    "                                        false)\n",
    "                                .subNDList(1);\n",
    "            }\n",
    "            outputs.add(vocab.getIdx(\"\" + c));\n",
    "        }\n",
    "\n",
    "        NDArray y;\n",
    "        for (int i = 0; i < numPreds; i++) {\n",
    "            NDList pair =\n",
    "                    castedNet.forward(\n",
    "                            new ParameterStore(manager, false),\n",
    "                            new NDList(getInput.apply()).addAll(state),\n",
    "                            false);\n",
    "            y = pair.get(0);\n",
    "            state = pair.subNDList(1);\n",
    "\n",
    "            outputs.add((int) y.argMax(1).reshape(new Shape(1)).getLong(0L));\n",
    "        }\n",
    "    }\n",
    "\n",
    "    StringBuilder output = new StringBuilder();\n",
    "    for (int i : outputs) {\n",
    "        output.append(vocab.idxToToken.get(i));\n",
    "    }\n",
    "    return output.toString();\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "/** Train a model. */\n",
    "public static void trainCh8(\n",
    "        Object net,\n",
    "        RandomAccessDataset dataset,\n",
    "        Vocab vocab,\n",
    "        int lr,\n",
    "        int numEpochs,\n",
    "        Device device,\n",
    "        boolean useRandomIter,\n",
    "        NDManager manager)\n",
    "        throws IOException, TranslateException {\n",
    "    SoftmaxCrossEntropyLoss loss = new SoftmaxCrossEntropyLoss();\n",
    "    Animator animator = new Animator();\n",
    "\n",
    "    Functions.voidTwoFunction<Integer, NDManager> updater;\n",
    "    if (net instanceof RNNModelScratch) {\n",
    "        RNNModelScratch castedNet = (RNNModelScratch) net;\n",
    "        updater =\n",
    "                (batchSize, subManager) ->\n",
    "                        Training.sgd(castedNet.params, lr, batchSize, subManager);\n",
    "    } else {\n",
    "        // Already initialized net\n",
    "        AbstractBlock castedNet = (AbstractBlock) net;\n",
    "        Model model = Model.newInstance(\"model\");\n",
    "        model.setBlock(castedNet);\n",
    "\n",
    "        Tracker lrt = Tracker.fixed(lr);\n",
    "        Optimizer sgd = Optimizer.sgd().setLearningRateTracker(lrt).build();\n",
    "\n",
    "        DefaultTrainingConfig config =\n",
    "                new DefaultTrainingConfig(loss)\n",
    "                        .optOptimizer(sgd) // Optimizer (loss function)\n",
    "                        .optInitializer(\n",
    "                                new NormalInitializer(0.01f),\n",
    "                                Parameter.Type.WEIGHT) // setting the initializer\n",
    "                        .optDevices(Engine.getInstance().getDevices(1)) // setting the number of GPUs needed\n",
    "                        .addEvaluator(new Accuracy()) // Model Accuracy\n",
    "                        .addTrainingListeners(TrainingListener.Defaults.logging()); // Logging\n",
    "\n",
    "        Trainer trainer = model.newTrainer(config);\n",
    "        updater = (batchSize, subManager) -> trainer.step();\n",
    "    }\n",
    "\n",
    "    Function<String, String> predict =\n",
    "            (prefix) -> predictCh8(prefix, 50, net, vocab, device, manager);\n",
    "    // Train and predict\n",
    "    double ppl = 0.0;\n",
    "    double speed = 0.0;\n",
    "    for (int epoch = 0; epoch < numEpochs; epoch++) {\n",
    "        Pair<Double, Double> pair =\n",
    "                trainEpochCh8(net, dataset, loss, updater, device, useRandomIter, manager);\n",
    "        ppl = pair.getKey();\n",
    "        speed = pair.getValue();\n",
    "        if ((epoch + 1) % 10 == 0) {\n",
    "           animator.add(epoch + 1, (float) ppl, \"ppl\");\n",
    "           animator.show();\n",
    "        }\n",
    "    }\n",
    "    System.out.format(\n",
    "            \"perplexity: %.1f, %.1f tokens/sec on %s%n\", ppl, speed, device.toString());\n",
    "    System.out.println(predict.apply(\"time traveller\"));\n",
    "    System.out.println(predict.apply(\"traveller\"));\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "/** Train a model within one epoch. */\n",
    "public static Pair<Double, Double> trainEpochCh8(\n",
    "        Object net,\n",
    "        RandomAccessDataset dataset,\n",
    "        Loss loss,\n",
    "        Functions.voidTwoFunction<Integer, NDManager> updater,\n",
    "        Device device,\n",
    "        boolean useRandomIter,\n",
    "        NDManager manager)\n",
    "        throws IOException, TranslateException {\n",
    "    StopWatch watch = new StopWatch();\n",
    "    watch.start();\n",
    "    Accumulator metric = new Accumulator(2); // Sum of training loss, no. of tokens\n",
    "\n",
    "    try (NDManager childManager = manager.newSubManager()) {\n",
    "        NDList state = null;\n",
    "        for (Batch batch : dataset.getData(childManager)) {\n",
    "            NDArray X = batch.getData().head().toDevice(device, true);\n",
    "            NDArray Y = batch.getLabels().head().toDevice(device, true);\n",
    "            if (state == null || useRandomIter) {\n",
    "                // Initialize `state` when either it is the first iteration or\n",
    "                // using random sampling\n",
    "                if (net instanceof RNNModelScratch) {\n",
    "                    state =\n",
    "                            ((RNNModelScratch) net)\n",
    "                                    .beginState((int) X.getShape().getShape()[0], device);\n",
    "                }\n",
    "            } else {\n",
    "                for (NDArray s : state) {\n",
    "                    s.stopGradient();\n",
    "                }\n",
    "            }\n",
    "            if (state != null) {\n",
    "                state.attach(childManager);\n",
    "            }\n",
    "\n",
    "            NDArray y = Y.transpose().reshape(new Shape(-1));\n",
    "            X = X.toDevice(device, false);\n",
    "            y = y.toDevice(device, false);\n",
    "            try (GradientCollector gc = Engine.getInstance().newGradientCollector()) {\n",
    "                NDArray yHat;\n",
    "                if (net instanceof RNNModelScratch) {\n",
    "                    Pair<NDArray, NDList> pairResult = ((RNNModelScratch) net).forward(X, state);\n",
    "                    yHat = pairResult.getKey();\n",
    "                    state = pairResult.getValue();\n",
    "                } else {\n",
    "                    NDList pairResult;\n",
    "                    if (state == null) {\n",
    "                        // Begin state\n",
    "                        pairResult =\n",
    "                                ((AbstractBlock) net)\n",
    "                                        .forward(\n",
    "                                                new ParameterStore(manager, false),\n",
    "                                                new NDList(X),\n",
    "                                                true);\n",
    "                    } else {\n",
    "                        pairResult =\n",
    "                                ((AbstractBlock) net)\n",
    "                                        .forward(\n",
    "                                                new ParameterStore(manager, false),\n",
    "                                                new NDList(X).addAll(state),\n",
    "                                                true);\n",
    "                    }\n",
    "                    yHat = pairResult.get(0);\n",
    "                    state = pairResult.subNDList(1);\n",
    "                }\n",
    "\n",
    "                NDArray l = loss.evaluate(new NDList(y), new NDList(yHat)).mean();\n",
    "                gc.backward(l);\n",
    "                metric.add(new float[] {l.getFloat() * y.size(), y.size()});\n",
    "            }\n",
    "            gradClipping(net, 1, childManager);\n",
    "            updater.apply(1, childManager); // Since the `mean` function has been invoked\n",
    "        }\n",
    "    }\n",
    "    return new Pair<>(Math.exp(metric.get(0) / metric.get(1)), metric.get(1) / watch.stop());\n",
    "}"
   ]
  },
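  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`trainEpochCh8` multiplies the mean loss of each minibatch by its number of tokens, accumulates both quantities, and returns the exponential of their ratio. This is the perplexity $\\exp\\left(-\\frac{1}{n} \\sum_{t=1}^n \\log P(x_t \\mid x_{t-1}, \\ldots, x_1)\\right)$ over the $n$ tokens of the epoch. The following cell is a minimal plain-Java sketch of that computation. `PerplexitySketch` is a hypothetical name, and the inputs are assumed to be per-minibatch mean cross-entropy losses (in nats) and token counts."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "/** Hypothetical helper: perplexity from per-minibatch mean losses, as accumulated in trainEpochCh8. */\n",
    "class PerplexitySketch {\n",
    "    static double perplexity(double[] meanLosses, long[] numTokens) {\n",
    "        double lossSum = 0; // sum of per-token cross-entropy losses\n",
    "        long tokenCount = 0;\n",
    "        for (int i = 0; i < meanLosses.length; i++) {\n",
    "            // Undo the `mean` so that minibatches with more tokens weigh more\n",
    "            lossSum += meanLosses[i] * numTokens[i];\n",
    "            tokenCount += numTokens[i];\n",
    "        }\n",
    "        return Math.exp(lossSum / tokenCount);\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sanity check, a model that assigns equal probability to each of $V$ tokens has a per-token loss of $\\log V$ and hence a perplexity of $V$. A perfect model reaches the minimum perplexity of 1."
   ]
  },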
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "/** Clip the gradient. */\n",
    "public static void gradClipping(Object net, int theta, NDManager manager) {\n",
    "    double result = 0;\n",
    "    NDList params;\n",
    "    if (net instanceof RNNModelScratch) {\n",
    "        params = ((RNNModelScratch) net).params;\n",
    "    } else {\n",
    "        params = new NDList();\n",
    "        for (Pair<String, Parameter> pair : ((AbstractBlock) net).getParameters()) {\n",
    "            params.add(pair.getValue().getArray());\n",
    "        }\n",
    "    }\n",
    "    for (NDArray p : params) {\n",
    "        NDArray gradient = p.getGradient().stopGradient();\n",
    "        gradient.attach(manager);\n",
    "        result += gradient.pow(2).sum().getFloat();\n",
    "    }\n",
    "    double norm = Math.sqrt(result);\n",
    "    if (norm > theta) {\n",
    "        for (NDArray param : params) {\n",
    "            NDArray gradient = param.getGradient();\n",
    "            gradient.muli(theta / norm);\n",
    "        }\n",
    "    }\n",
    "}"
   ]
  },
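  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`gradClipping` rescales the gradients of all parameters jointly. If $\\mathbf{g}$ denotes all gradients concatenated into one vector, it applies $\\mathbf{g} \\leftarrow \\min\\left(1, \\frac{\\theta}{\\|\\mathbf{g}\\|}\\right) \\mathbf{g}$. The direction of $\\mathbf{g}$ is unchanged, and its norm never exceeds $\\theta$. The following cell is a minimal plain-Java sketch of the same rule on raw `double[]` gradients, without DJL. `ClipSketch` is a hypothetical name."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import java.util.List;\n",
    "\n",
    "/** Hypothetical plain-Java version of gradClipping on raw gradient arrays. */\n",
    "class ClipSketch {\n",
    "    /** Rescales all gradients in place so their joint L2 norm is at most theta; returns the norm before clipping. */\n",
    "    static double clip(List<double[]> grads, double theta) {\n",
    "        double sumSq = 0;\n",
    "        for (double[] g : grads) {\n",
    "            for (double v : g) {\n",
    "                sumSq += v * v;\n",
    "            }\n",
    "        }\n",
    "        double norm = Math.sqrt(sumSq); // norm over all parameters together\n",
    "        if (norm > theta) {\n",
    "            double scale = theta / norm;\n",
    "            for (double[] g : grads) {\n",
    "                for (int i = 0; i < g.length; i++) {\n",
    "                    g[i] *= scale;\n",
    "                }\n",
    "            }\n",
    "        }\n",
    "        return norm;\n",
    "    }\n",
    "}"
   ]
  },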
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we will leverage the dataset that we just created and assign the required parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "int batchSize = 32;\n",
    "int numSteps = 35;\n",
    "\n",
    "TimeMachineDataset dataset = new TimeMachineDataset.Builder()\n",
    "        .setManager(manager).setMaxTokens(10000).setSampling(batchSize, false)\n",
    "        .setSteps(numSteps).build();\n",
    "dataset.prepare();\n",
    "Vocab vocab = dataset.getVocab();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 3
   },
   "source": [
    "## Defining the Model\n",
    "\n",
    "High-level APIs provide implementations of recurrent neural networks.\n",
    "We construct the recurrent neural network layer `rnn_layer` with a single hidden layer and 256 hidden units.\n",
    "In fact, we have not even discussed yet what it means to have multiple layers---this will happen in :numref:`sec_deep_rnn`.\n",
    "For now, suffice it to say that multiple layers simply amount to the output of one layer of RNN being used as the input for the next layer of RNN.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "int numHiddens = 256;\n",
    "RNN rnnLayer = RNN.builder().setNumLayers(1)\n",
    "        .setStateSize(numHiddens).optReturnState(true).optBatchFirst(false).build();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 6,
    "tab": [
     "mxnet"
    ]
   },
   "source": [
    "Initializing the hidden state is straightforward.\n",
    "We invoke the member function `beginState` _(In DJL we don't have to run `beginState` to later specify the resulting state the first time we run `forward`, as this logic is ran by DJL the first time we do `forward` but we will create it here for demonstration purposes)_.\n",
    "This returns a list (`state`)\n",
    "that contains\n",
    "an initial hidden state\n",
    "for each example in the minibatch,\n",
    "whose shape is\n",
    "(number of hidden layers, batch size, number of hidden units).\n",
    "For some models \n",
    "to be introduced later \n",
    "(e.g., long short-term memory),\n",
    "such a list also\n",
    "contains other information."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "public static NDList beginState(int batchSize, int numLayers, int numHiddens) {\n",
    "    return new NDList(manager.zeros(new Shape(numLayers, batchSize, numHiddens)));\n",
    "}\n",
    "\n",
    "NDList state = beginState(batchSize, 1, numHiddens);\n",
    "System.out.println(state.size());\n",
    "System.out.println(state.get(0).getShape());"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 10
   },
   "source": [
    "With a hidden state and an input,\n",
    "we can compute the output with\n",
    "the updated hidden state.\n",
    "It should be emphasized that\n",
    "the \"output\" (`Y`) of `rnnLayer`\n",
    "does *not* involve computation of output layers:\n",
    "it refers to \n",
    "the hidden state at *each* time step,\n",
    "and they can be used as the input\n",
    "to the subsequent output layer."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 11,
    "tab": [
     "mxnet"
    ]
   },
   "source": [
    "Besides,\n",
    "the updated hidden state (`stateNew`) returned by `rnnLayer`\n",
    "refers to the hidden state\n",
    "at the *last* time step of the minibatch.\n",
    "It can be used to initialize the \n",
    "hidden state for the next minibatch within an epoch\n",
    "in sequential partitioning.\n",
    "For multiple hidden layers,\n",
    "the hidden state of each layer will be stored\n",
    "in this variable (`stateNew`).\n",
    "For some models \n",
    "to be introduced later \n",
    "(e.g., long short-term memory),\n",
    "this variable also\n",
    "contains other information."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "NDArray X = manager.randomUniform (0, 1,new Shape(numSteps, batchSize, vocab.length()));\n",
    "\n",
    "NDList input = new NDList(X, state.get(0));\n",
    "rnnLayer.initialize(manager, DataType.FLOAT32, input.getShapes());\n",
    "NDList forwardOutput = rnnLayer.forward(new ParameterStore(manager, false), input, false);\n",
    "NDArray Y = forwardOutput.get(0);\n",
    "NDArray stateNew = forwardOutput.get(1);\n",
    "\n",
    "System.out.println(Y.getShape());\n",
    "System.out.println(stateNew.getShape());"
   ]
  },
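  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see why `Y` holds the hidden state of every time step while `stateNew` holds only the last one, the following cell sketches a single-layer RNN forward pass in plain Java with a batch size of 1. It assumes the recurrence $\\mathbf{H}_t = \\tanh(\\mathbf{X}_t \\mathbf{W}_{xh} + \\mathbf{H}_{t-1} \\mathbf{W}_{hh} + \\mathbf{b}_h)$ with a tanh activation. `RnnSketch` is a hypothetical name, and this is an illustration, not DJL's implementation of `RNN`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "/** Hypothetical single-layer tanh RNN forward pass on plain arrays (batch size 1). */\n",
    "class RnnSketch {\n",
    "    /**\n",
    "     * xs: inputs [numSteps][inputSize], wxh: [inputSize][hidden], whh: [hidden][hidden],\n",
    "     * bh: [hidden], h0: initial hidden state [hidden].\n",
    "     * Returns the hidden state at every time step, shape [numSteps][hidden].\n",
    "     */\n",
    "    static double[][] forward(double[][] xs, double[][] wxh, double[][] whh, double[] bh, double[] h0) {\n",
    "        int hidden = bh.length;\n",
    "        double[][] outputs = new double[xs.length][];\n",
    "        double[] h = h0;\n",
    "        for (int t = 0; t < xs.length; t++) {\n",
    "            double[] next = new double[hidden];\n",
    "            for (int j = 0; j < hidden; j++) {\n",
    "                double z = bh[j];\n",
    "                for (int i = 0; i < xs[t].length; i++) {\n",
    "                    z += xs[t][i] * wxh[i][j];\n",
    "                }\n",
    "                for (int k = 0; k < hidden; k++) {\n",
    "                    z += h[k] * whh[k][j];\n",
    "                }\n",
    "                next[j] = Math.tanh(z);\n",
    "            }\n",
    "            outputs[t] = next; // one row of the output Y per time step\n",
    "            h = next;          // carried over to the next time step\n",
    "        }\n",
    "        return outputs; // outputs[xs.length - 1] is the final hidden state\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Stacking all rows of `outputs` corresponds to `Y`, whose shape in DJL is (number of steps, batch size, number of hidden units). The last row is the state carried over, which corresponds to `stateNew` for a single hidden layer."
   ]
  },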
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 14
   },
   "source": [
    "Similar to :numref:`sec_rnn_scratch`,\n",
    "we define an `RNNModel` class \n",
    "for a complete RNN model.\n",
    "Note that `rnnLayer` only contains the hidden recurrent layers, we need to create a separate output layer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "public class RNNModel extends AbstractBlock {\n",
    "\n",
    "    private RNN rnnLayer;\n",
    "    private Linear dense;\n",
    "    private int vocabSize;\n",
    "\n",
    "    public RNNModel(RNN rnnLayer, int vocabSize) {\n",
    "        this.rnnLayer = rnnLayer;\n",
    "        this.addChildBlock(\"rnn\", rnnLayer);\n",
    "        this.vocabSize = vocabSize;\n",
    "        this.dense = Linear.builder().setUnits(vocabSize).build();\n",
    "        this.addChildBlock(\"linear\", dense);\n",
    "    }\n",
    "\n",
    "    @Override\n",
    "    protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, PairList<String, Object> params) {\n",
    "        NDArray X = inputs.get(0).transpose().oneHot(vocabSize);\n",
    "        inputs.set(0, X);\n",
    "        NDList result = rnnLayer.forward(parameterStore, inputs, training);\n",
    "        NDArray Y = result.get(0);\n",
    "        NDArray state = result.get(1);\n",
    "\n",
    "        int shapeLength = Y.getShape().dimension();\n",
    "        NDList output = dense.forward(parameterStore, new NDList(Y\n",
    "                .reshape(new Shape(-1, Y.getShape().get(shapeLength-1)))), training);\n",
    "        return new NDList(output.get(0), state);\n",
    "    }\n",
    "    \n",
    "    @Override\n",
    "    public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes) {\n",
    "        Shape shape = rnnLayer.getOutputShapes(new Shape[]{inputShapes[0]})[0];\n",
    "        dense.initialize(manager, dataType, new Shape(vocabSize, shape.get(shape.dimension() - 1)));\n",
    "    }\n",
    "\n",
    "    /* We won't implement this since we won't be using it but it's required as part of an AbstractBlock  */\n",
    "    @Override\n",
    "    public Shape[] getOutputShapes(Shape[] inputShapes) {\n",
    "        return new Shape[0];\n",
    "    }\n",
    "}"
   ]
  },
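  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In `forwardInternal`, the input of shape (batch size, number of steps) is transposed before one-hot encoding, so `rnnLayer` receives time-major data. `Y` is then reshaped to (`numSteps * batchSize`, number of hidden units), and row `t * batchSize + b` belongs to example `b` at time step `t`. For this reason `trainEpochCh8` flattens the labels with `Y.transpose().reshape(new Shape(-1))`. The following cell is a plain-Java sketch of that label ordering. `LayoutSketch` and `timeMajor` are hypothetical names."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "/** Hypothetical helper mirroring how trainEpochCh8 orders the labels. */\n",
    "class LayoutSketch {\n",
    "    /** Flattens labels of shape [batchSize][numSteps] in time-major order. */\n",
    "    static int[] timeMajor(int[][] labels) {\n",
    "        int batchSize = labels.length;\n",
    "        int numSteps = labels[0].length;\n",
    "        int[] flat = new int[batchSize * numSteps];\n",
    "        for (int t = 0; t < numSteps; t++) {\n",
    "            for (int b = 0; b < batchSize; b++) {\n",
    "                // Same row index as the corresponding output logits of RNNModel\n",
    "                flat[t * batchSize + b] = labels[b][t];\n",
    "            }\n",
    "        }\n",
    "        return flat;\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For labels `{{1, 2, 3}, {4, 5, 6}}` (2 examples, 3 steps) this gives `[1, 4, 2, 5, 3, 6]`: all examples at step 0 come first, then step 1, and so on, matching the order of the rows that `dense` produces."
   ]
  },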
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 17
   },
   "source": [
    "## Training and Predicting\n",
    "\n",
    "Before training the model, let us make a prediction with the a model that has random weights.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Device device = manager.getDevice();\n",
    "RNNModel net = new RNNModel(rnnLayer, vocab.length());\n",
    "net.initialize(manager, DataType.FLOAT32, X.getShape());\n",
    "predictCh8(\"time traveller\", 10, net, vocab, device, manager);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 20
   },
   "source": [
    "As is quite obvious, this model does not work at all. Next, we call `trainCh8` with the same hyperparameters defined in :numref:`sec_rnn_scratch` and train our model with high-level APIs.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "int numEpochs = Integer.getInteger(\"MAX_EPOCH\", 500);\n",
    "\n",
    "int lr = 1;\n",
    "trainCh8((Object) net, dataset, vocab, lr, numEpochs, device, false, manager);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 22
   },
   "source": [
    "Compared with the last section, this model achieves comparable perplexity,\n",
    "albeit within a shorter period of time, due to the code being more optimized by\n",
    "high-level APIs of the deep learning framework.\n",
    "\n",
    "\n",
    "## Summary\n",
    "\n",
    "* High-level APIs of the deep learning framework provides an implementation of the RNN layer.\n",
    "* The RNN layer of high-level APIs returns an output and an updated hidden state, where the output does not involve output layer computation.\n",
    "* Using high-level APIs leads to faster RNN training than using its implementation from scratch.\n",
    "\n",
    "## Exercises\n",
    "\n",
    "1. Can you make the RNN model overfit using the high-level APIs?\n",
    "1. What happens if you increase the number of hidden layers in the RNN model? Can you make the model work?\n",
    "1. Implement the autoregressive model of :numref:`sec_sequence` using an RNN.\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Java",
   "language": "java",
   "name": "java"
  },
  "language_info": {
   "codemirror_mode": "java",
   "file_extension": ".jshell",
   "mimetype": "text/x-java-source",
   "name": "Java",
   "pygments_lexer": "java",
   "version": "14.0.2+12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
